import sys
import json
import numpy as np

from tqdm import tqdm
from pathlib import Path
from sklearn.metrics import accuracy_score


from scorers.automatic import AutomaticScorer


def _index_ground_truth_by_uid(gd_dataset):
    """Map ``clinical_case_uid`` to its ground-truth record, when that is safe.

    Returns an empty dict (meaning "use positional lookup") unless
    ``gd_dataset`` is a list whose records are all dicts carrying a unique
    str/int ``clinical_case_uid``.
    NOTE(review): whether the ground-truth file carries ``clinical_case_uid``
    is not visible here — confirm against the data file.
    """
    if not isinstance(gd_dataset, list):
        return {}
    lookup = {}
    for record in gd_dataset:
        if not isinstance(record, dict):
            return {}
        uid = record.get('clinical_case_uid')
        if not isinstance(uid, (str, int)) or uid in lookup:
            # Missing, unhashable or duplicated identifier: the join would be ambiguous.
            return {}
        lookup[uid] = record
    return lookup


def _ground_truth_record(case, index, gd_dataset, gd_by_uid):
    """Return the ground-truth record for ``case``.

    Prefers the ``clinical_case_uid`` join and falls back to ``gd_dataset[index]``.
    """
    if gd_by_uid and isinstance(case, dict):
        uid = case.get('clinical_case_uid')
        if isinstance(uid, (str, int)) and uid in gd_by_uid:
            return gd_by_uid[uid]
    return gd_dataset[index]


def _guide_following_rates(guide_dataset, k, valid_department_names):
    """Return ``(quantity_following_rate, name_following_rate)`` in percent.

    A case follows the quantity instruction when it predicts exactly ``k``
    departments, and the name instruction when every predicted department is
    in ``valid_department_names``. An empty dataset yields ``(0.0, 0.0)``
    instead of dividing by zero.
    """
    total_cases = len(guide_dataset)
    if total_cases == 0:
        return 0.0, 0.0
    quantity_following_cases = 0
    name_following_cases = 0
    for item in guide_dataset:
        predicted_departments = item['predicted_clinical_department']
        if len(predicted_departments) == k:
            quantity_following_cases += 1
        if all(name in valid_department_names for name in predicted_departments):
            name_following_cases += 1
    return (quantity_following_cases / total_cases) * 100, (name_following_cases / total_cases) * 100


def _text_generation_scores(scorer, language, predictions, references):
    """Score generated text with BLEU, ROUGE and BERTScore.

    Returns ``(metrics, score)`` where ``metrics`` holds the three raw scorer
    results and ``score`` is their mean rounded to two decimals.
    NOTE(review): assumes each scorer method returns a single number, since
    the results are averaged with ``np.mean`` — confirm against AutomaticScorer.

    Raises:
        Exception: if ``predictions`` and ``references`` differ in length.
    """
    if len(predictions) != len(references):
        raise Exception(f'### [Unequal length]')
    metrics = {
        'bleu': scorer.calculate_bleu(language, predictions, references),
        'rouge': scorer.calculate_rouge(language, predictions, references),
        'bertscore': scorer.calculate_bertscore(language, predictions, references),
    }
    score = round(float(np.mean([metrics['bleu'], metrics['rouge'], metrics['bertscore']])), 2)
    return metrics, score


def calculate_metrics(score_item, clinical_department, inference_dataset, guide3_dataset, guide5_dataset, gd_dataset,
                      language):
    """Compute all task scores for one group of cases and store them in ``score_item``.

    Results are written to ``score_item[clinical_department]``: raw metrics
    under ``'metrics'`` and rounded per-task scores next to it, plus the
    ``'hospital_guide'``, ``'clinical_diagnosis'``, ``'average'`` and
    ``'weighted_average'`` aggregates.

    Reads the module-level ``clinical_department_zh_list`` /
    ``clinical_department_en_list`` (chosen by ``language``) and
    ``clinical_diagnosis_part_list``.

    Args:
        score_item: Result dict for one model; mutated in place.
        clinical_department: Key under which results are stored (a department
            name, or a label such as ``'overall'``).
        inference_dataset: Case records with predictions for this group.
        guide3_dataset: Guide records where 3 departments were requested.
        guide5_dataset: Guide records where 5 departments were requested.
        gd_dataset: Ground-truth diagnosis records. Matched to cases through
            ``clinical_case_uid`` when every record has a unique one, so that a
            filtered ``inference_dataset`` is still compared with the right
            records; otherwise matched by position.
        language: ``'zh'`` selects the Chinese department names, anything else
            the English ones; also forwarded to the scorer.

    Returns:
        The same ``score_item`` object.

    Raises:
        Exception: if a prediction list and its reference list differ in length.
    """
    department_score = {'metrics': {}}
    score_item[clinical_department] = department_score
    metrics = department_score['metrics']

    scorer = AutomaticScorer()
    # Validate predicted names against the list for the evaluated language,
    # the same way the caller picks the department list.
    department_names = clinical_department_zh_list if language == 'zh' else clinical_department_en_list

    # ====================================================
    # Hospital-guide task
    # ====================================================
    predictions = [0] * len(inference_dataset)
    references = [1] * len(inference_dataset)
    for index, case in enumerate(tqdm(inference_dataset)):
        predicted_department = case['predicted_clinical_department']
        ground_truth_department = case['clinical_department']
        # A prediction that matches more than one known department is counted as wrong.
        hit_count = sum(
            1 for name in department_names
            if name in predicted_department or predicted_department in name)
        if hit_count > 1:
            print(f"{hit_count}, {case['clinical_case_uid']}")
        if hit_count == 1 and (ground_truth_department in predicted_department
                               or predicted_department in ground_truth_department):
            predictions[index] = 1
    # Sub-task metric: references are all 1, so accuracy is the share of correct predictions.
    accuracy = accuracy_score(y_true=references, y_pred=predictions) * 100
    metrics['guide_departmental_accuracy'] = {'accuracy': accuracy}
    # Sub-task score
    department_score['guide_departmental_accuracy'] = round(accuracy, 2)

    # Sub-task metric: instruction following for the 3- and 5-department prompts.
    following = {}
    metrics['guide_departmental_instruction_following_rate'] = following
    for k, guide_dataset in ((3, guide3_dataset), (5, guide5_dataset)):
        quantity_rate, name_rate = _guide_following_rates(guide_dataset, k, department_names)
        following[f'guide{k}_departmental_quantity_following_rate'] = quantity_rate
        following[f'guide{k}_departmental_name_following_rate'] = name_rate
        following[f'guide{k}_departmental_instruction_following_rate'] = np.mean([quantity_rate, name_rate])
    for aspect in ('quantity', 'name', 'instruction'):
        following[f'guide_departmental_{aspect}_following_rate'] = np.mean(
            [following[f'guide{k}_departmental_{aspect}_following_rate'] for k in (3, 5)])
    # Sub-task score
    department_score['guide_departmental_instruction_following_rate'] = round(float(np.mean([
        following['guide_departmental_quantity_following_rate'],
        following['guide_departmental_name_following_rate'],
    ])), 2)

    # Hospital-guide task total
    department_score['hospital_guide'] = round(float(np.mean([
        department_score['guide_departmental_accuracy'],
        department_score['guide_departmental_instruction_following_rate'],
    ])), 2)

    # ====================================================
    # Clinical-diagnosis task
    # ====================================================
    gd_by_uid = _index_ground_truth_by_uid(gd_dataset)
    for part in ('preliminary_diagnosis', 'principal_diagnosis'):
        predictions = [0] * len(inference_dataset)
        references = [1] * len(inference_dataset)
        for index, case in enumerate(tqdm(inference_dataset)):
            predicted_part = case[f'predicted_{part}']
            ground_truth_part = _ground_truth_record(case, index, gd_dataset, gd_by_uid)[f'ground_truth_{part}']
            # Correct as soon as any ground-truth diagnosis occurs in the prediction.
            if any(disease_diagnosis in predicted_part for disease_diagnosis in ground_truth_part):
                predictions[index] = 1
        # Sub-task metric
        accuracy = accuracy_score(y_true=references, y_pred=predictions) * 100
        metrics[part] = {'accuracy': accuracy}
        # Sub-task score
        department_score[part] = round(accuracy, 2)

    for part in ('diagnostic_basis', 'differential_diagnosis', 'therapeutic_principle', 'treatment_plan'):
        predictions = []
        references = []
        for case in tqdm(inference_dataset):
            predictions.append(case[f'predicted_{part}'])
            references.append(case[part])
        # Sub-task metrics and score
        metrics[part], department_score[part] = _text_generation_scores(scorer, language, predictions, references)

    # Clinical-diagnosis task total
    department_score['clinical_diagnosis'] = round(
        float(np.mean([department_score[part] for part in clinical_diagnosis_part_list])), 2)

    # ====================================================
    # Imaging-diagnosis task
    # ====================================================
    predictions = []
    references = []
    for case in tqdm(inference_dataset):
        examinations = case['imageological_examination']
        # Only dict-valued examinations contribute impressions.
        if isinstance(examinations, dict):
            for examination in examinations.values():
                predictions.append(examination['predicted_impression'])
                references.append(examination['impression'])
    # Task metrics and total
    metrics['imaging_diagnosis'], department_score['imaging_diagnosis'] = _text_generation_scores(
        scorer, language, predictions, references)

    # ====================================================
    # Plain and weighted averages over all sub-task scores
    # ====================================================
    weight_dict = {
        'guide_departmental_accuracy': 5,
        'guide_departmental_instruction_following_rate': 1,
        'preliminary_diagnosis': 1,
        'diagnostic_basis': 1,
        'differential_diagnosis': 1,
        'principal_diagnosis': 5,
        'therapeutic_principle': 1,
        'treatment_plan': 1,
        'imaging_diagnosis': 1,
    }
    total_weighted_score = sum(department_score[key] * weight for key, weight in weight_dict.items())
    total_weights = sum(weight_dict.values())
    weighted_average = round((total_weighted_score / total_weights), 2)
    department_score['average'] = round(
        float(np.mean([department_score[key] for key in weight_dict])), 2)
    department_score['weighted_average'] = weighted_average

    return score_item


def main():
    """Score every configured model and write the ranked results to a JSON file.

    For each name in ``model_name_list`` this loads the inference file and the
    two guide files (``acc3`` / ``acc5``), scores each clinical department
    separately and then all cases together under the key ``'overall'``. The
    collected items are sorted by overall weighted average, highest first, and
    dumped to ``score_dir / score_save_name``.

    All configuration (paths, names, mappings, ``language``) is read from
    module-level globals.
    """

    def read_json(path):
        # Every input file is read as UTF-8 JSON.
        with open(path, mode='r', encoding='utf-8') as file:
            return json.load(file)

    gd_zh_dataset = read_json(data_dir / Path(gd_zh_load_name))
    # English ground truth is not loaded; a non-'zh' run passes None to the scorer.
    gd_en_dataset = None

    score_dict = {'code': 0, 'data': []}
    for model_name in model_name_list:
        print(model_name)
        inference_dataset = read_json(inference_dir / Path(f'inference_{language}_{model_name}.json'))
        guide3_dataset = read_json(guide_dir / Path(f'guide_{language}_{model_name}_acc3.json'))
        guide5_dataset = read_json(guide_dir / Path(f'guide_{language}_{model_name}_acc5.json'))

        score_item = {
            'model': model_name_mapping_dict[model_name],
            'institution': institution_name_mapping_dict[model_name],
            'url': institution_url_mapping_dict[model_name],
        }

        # Pick the department names and ground truth matching the language.
        if language == 'zh':
            clinical_department_list, gd_dataset = clinical_department_zh_list, gd_zh_dataset
        else:
            clinical_department_list, gd_dataset = clinical_department_en_list, gd_en_dataset

        # Split all three datasets by department before any scoring starts.
        cases_by_department = [
            (
                department,
                [item for item in inference_dataset if item['clinical_department'] == department],
                [item for item in guide3_dataset if item['clinical_department'] == department],
                [item for item in guide5_dataset if item['clinical_department'] == department],
            )
            for department in clinical_department_list
        ]

        # Per-department comparison.
        for department, inference_cases, guide3_cases, guide5_cases in cases_by_department:
            print(f'{department}: {len(inference_cases)}')
            score_item = calculate_metrics(score_item, department, inference_cases, guide3_cases, guide5_cases,
                                           gd_dataset, language)
        # Scores over all cases.
        score_item = calculate_metrics(score_item, 'overall', inference_dataset, guide3_dataset, guide5_dataset,
                                       gd_dataset, language)
        print(score_item)

        score_dict['data'].append(score_item)

    # Rank models by their overall weighted average, best first.
    score_dict['data'].sort(key=lambda entry: entry['overall']['weighted_average'], reverse=True)
    print(score_dict)

    with open(str(score_dir / Path(score_save_name)), mode='w', encoding='utf-8') as file:
        json.dump(score_dict, file, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    # Run configuration: everything below is a module-level global read by
    # main() and calculate_metrics().
    sys.setrecursionlimit(3000)

    language = 'zh'
    gd_zh_load_name = 'disease_diagnosis_ground_truth.json'
    gd_en_load_name = ''
    score_save_name = 'score_weighted_average(2024-05-27).json'

    # Models scored in this run; the commented-out entries are skipped.
    model_name_list = [
        'baichuan2chat',
        # 'bianque2',
        'bluelmchat',
        'chatglm3',
        'claude3',
        # 'discmedllm',
        'geminipro',
        'gpt3.5',
        'gpt4',
        'huatuogpt2',
        'internlm2chat',
        'pulse',
        'qwenchat',
        'spark3',
        'taiyillm',
        'wingpt2',
        'yichat',
    ]

    # Sub-tasks averaged into the clinical-diagnosis total.
    clinical_diagnosis_part_list = [
        'preliminary_diagnosis',
        'diagnostic_basis',
        'differential_diagnosis',
        'principal_diagnosis',
        'therapeutic_principle',
        'treatment_plan',
    ]

    # One row per known model: (file key, display name, institution, URL).
    model_registry = (
        ('baichuan2chat', 'Baichuan2-13B-Chat', 'Baichuan AI',
         'https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat'),
        ('bianque2', 'BianQue-2', 'SCUT-FT',
         'https://huggingface.co/scutcyr/BianQue-2'),
        ('bluelmchat', 'BlueLM-7B-Chat', 'Vivo',
         'https://huggingface.co/vivo-ai/BlueLM-7B-Chat'),
        ('chatglm3', 'ChatGLM3-6B', 'THUDM & Zhipu AI',
         'https://huggingface.co/THUDM/chatglm3-6b'),
        ('claude3', 'Claude-3', 'Anthropic',
         'https://www.anthropic.com/news/claude-3-haiku'),
        ('discmedllm', 'DISC-MedLLM', 'Fudan-DISC',
         'https://huggingface.co/Flmc/DISC-MedLLM'),
        ('geminipro', 'Gemini-Pro', 'Google',
         'https://ai.google.dev/models/gemini'),
        ('gpt3.5', 'GPT-3.5', 'OpenAI',
         'https://platform.openai.com/docs/models/gpt-3-5'),
        ('gpt4', 'GPT-4', 'OpenAI',
         'https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo'),
        ('huatuogpt2', 'HuatuoGPT2-34B', 'CUHK-Shenzhen',
         'https://huggingface.co/FreedomIntelligence/HuatuoGPT2-34B'),
        ('internlm2chat', 'InternLM2-20B-Chat', 'Shanghai AI Laboratory',
         'https://huggingface.co/internlm/internlm2-chat-20b'),
        ('pulse', 'PULSE-20B', 'Shanghai AI Laboratory',
         'https://huggingface.co/OpenMEDLab/PULSE-20bv5'),
        ('qwenchat', 'Qwen-72B-Chat', 'Alibaba Cloud',
         'https://huggingface.co/Qwen/Qwen-72B-Chat'),
        ('spark3', 'Spark-3', 'iFLYTEK',
         'https://xinghuo.xfyun.cn/'),
        ('taiyillm', 'Taiyi-LLM', 'DUTIR-BioNLP',
         'https://huggingface.co/DUTIR-BioNLP/Taiyi-LLM'),
        ('wingpt2', 'WiNGPT2-14B-Chat', 'Winning Health',
         'https://huggingface.co/winninghealth/WiNGPT2-14B-Chat'),
        ('yichat', 'Yi-34B-Chat', '01 AI',
         'https://huggingface.co/01-ai/Yi-34B-Chat'),
    )
    model_name_mapping_dict = {key: display_name for key, display_name, _, _ in model_registry}
    institution_name_mapping_dict = {key: institution for key, _, institution, _ in model_registry}
    institution_url_mapping_dict = {key: url for key, _, _, url in model_registry}

    # Chinese department name -> English department name.
    clinical_department_zh_to_en_dict = {
        '乳腺外科': 'breast surgical department',
        '产科': 'obstetrics department',
        '儿科': 'pediatrics department',
        '内分泌内科': 'endocrinology department',
        '呼吸内科': 'respiratory medicine department',
        '妇科': 'gynecology department',
        '心脏外科': 'cardiac surgical department',
        '心血管内科': 'cardiovascular medicine department',
        '泌尿外科': 'urinary surgical department',
        '消化内科': 'gastroenterology department',
        '甲状腺外科': 'thyroid surgical department',
        '疝外科': 'hernia surgical department',
        '神经内科': 'neurology department',
        '神经外科': 'neurosurgery department',
        '耳鼻咽喉头颈外科': 'otolaryngology head and neck surgical department',
        '肛门结直肠外科': 'anus and intestine surgical department',
        '肝胆胰外科': 'hepatobiliary and pancreas surgical department',
        '肾内科': 'nephrology department',
        '胃肠外科': 'gastrointestinal surgical department',
        '胸外科': 'thoracic surgical department',
        '血液内科': 'hematology department',
        '血管外科': 'vascular surgical department',
        '骨科': 'orthopedics department',
    }
    clinical_department_zh_list = [*clinical_department_zh_to_en_dict]
    clinical_department_en_list = [*clinical_department_zh_to_en_dict.values()]

    # Working directories sit next to this script's parent directory and are
    # created when missing.
    project_root = Path(__file__).parent.parent
    inference_dir = project_root / Path('inferences')
    guide_dir = project_root / Path('guides')
    data_dir = project_root / Path('data')
    score_dir = project_root / Path('scores')
    for directory in (inference_dir, guide_dir, data_dir, score_dir):
        if not directory.is_dir():
            directory.mkdir(parents=True, exist_ok=True)

    main()